1 Introduction

This R Markdown script analyses data from the PAL (probabilistic associative learning) task of the EMBA project. HGF parameters were extracted based on the subject-specific reaction times beforehand in MATLAB.

1.1 Some general settings

# number of simulations (posterior predictive draws used in the pp checks below)
nsim = 250

# set number of iterations and warmup for models
# (4 chains x (3000 - 1000) = 8000 post-warmup draws, matching the model summaries)
iter = 3000
warm = 1000

# set the seed for reproducibility of all stochastic steps
set.seed(2468)

1.2 Package versions

The following packages are used in this RMarkdown file:

## [1] "R version 4.5.1 (2025-06-13)"
## [1] "knitr version 1.50"
## [1] "ggplot2 version 4.0.0"
## [1] "brms version 2.22.0"
## [1] "designr version 0.1.13"
## [1] "bridgesampling version 1.1.2"
## [1] "tidyverse version 2.0.0"
## [1] "ggpubr version 0.6.1"
## [1] "ggrain version 0.0.4"
## [1] "bayesplot version 1.13.0"
## [1] "SBC version 0.3.0.9000"
## [1] "rstatix version 0.7.2"
## [1] "easystats version 0.7.5"
## [1] "BayesFactor version 0.9.12.4.7"
## [1] "bayestestR version 0.17.0"

1.3 Preparation

First, we load the parameters from the winning model.

# get HGF parameters
# (show_col_types = FALSE suppresses the column-spec message consistently for
# all read_csv() calls; was missing here and abbreviated as `F` below)
df.hgf = read_csv(file.path("HGF_results/main", "eHGF-L21_results.csv"),
                  show_col_types = FALSE) %>%
  # add the per-subject covariates (EDT, medication status) from the raw data
  merge(., read_csv("../data/PAL-ADHD_data.csv", show_col_types = FALSE) %>%
          select(subID, EDT, adhd.meds.bin) %>% distinct()) %>%
  # superseded mutate_if() replaced by the current across(where()) idiom
  mutate(across(where(is.character), as.factor))

# get belief state trajectories
df.trj = read_csv(file.path("HGF_results/main", "eHGF-L21_traj.csv"),
                  show_col_types = FALSE)

# extract the absolute changes in learning rate for the phases
df.upd = df.trj %>%
  select(subID, diagnosis, trl, alpha2, alpha3) %>% ungroup() %>%
  mutate(
    # code the phases > only take the beginning and end of volatile
    # trials between the cut-offs fall through to NA and are removed by
    # drop_na() below; cut-offs presumably follow the task design — TODO confirm
    phase = case_when(
      trl < 73  ~ "pre",
      trl > 264 ~ "post",
      trl < 145 ~ "vol1",
      trl > 192 ~ "vol2"
    )
  ) %>%
  drop_na() %>%
  # one median learning rate per subject, level, and phase
  group_by(subID, diagnosis, phase) %>%
  summarise(
    alpha2 = median(alpha2),
    alpha3 = median(alpha3)
  ) %>%
  # one row per subject with columns alpha{2,3}_{pre,post,vol1,vol2}
  pivot_wider(names_from = phase, id_cols = c(subID, diagnosis), values_from = starts_with("alpha")) %>%
  group_by(subID, diagnosis) %>%
  # absolute change into (pre -> vol1) and out of (vol2 -> post) the volatile phase
  summarise(
    alpha2_pre2vol  = abs(alpha2_pre  - alpha2_vol1),
    alpha2_vol2post = abs(alpha2_post - alpha2_vol2),
    alpha3_pre2vol  = abs(alpha3_pre  - alpha3_vol1),
    alpha3_vol2post = abs(alpha3_post - alpha3_vol2)
  ) %>% 
  # long format with `name` like "alpha2_pre2vol", then split into level/change
  pivot_longer(cols = starts_with("alpha")) %>%
  separate(name, into = c("level", "change")) %>%
  merge(., df.hgf %>% select(subID, EDT)) %>%
  mutate_if(is.character, as.factor)

# check whether there are LME differences between the diagnostic groups
# (Shapiro-Wilk per group to justify the parametric Bayesian ANOVA below;
# LME is presumably the log model evidence from the HGF fit — see section text)
kable(df.hgf %>% group_by(diagnosis) %>% shapiro_test(LME)) # all normally distributed
diagnosis variable statistic p
ADHD LME 0.9624505 0.5403480
BOTH LME 0.9732667 0.7853028
COMP LME 0.9722848 0.7633392
# load the cached Bayesian ANOVA if it exists, otherwise fit it and cache it
if (file.exists(file.path(brms_dir, "aov_lme.rds"))) {
  aov.lme = readRDS(file.path(brms_dir, "aov_lme.rds"))
} else {
  aov.lme = anovaBF(LME ~ diagnosis, data = df.hgf)
  saveRDS(aov.lme, file.path(brms_dir, "aov_lme.rds"))
}
# log Bayes factor for the diagnosis effect (negative = evidence against)
aov.lme@bayesFactor
##                   bf        error                     time        code
## diagnosis -0.5840522 9.993788e-05 Thu Oct 30 11:19:46 2025 eca1873f721

There is anecdotal evidence against a difference in LME between diagnostic groups. This suggests that the eHGF model fit the subjects of the different groups comparably well. Therefore, we move on to analyse its parameters.

The response model best fitting to our data was the one employed by Lawson et al. (2021): \[\log{RT} = \beta_0 + \beta_1 \times surprise_{stimulus} + \beta_2 \times pwPE + \beta_3 \times volatility_{phasic}\] Next, we use sum contrast coding for all of our categorical predictors.

# set and print the contrasts (sum coding; COMP is the implicit -1/-1 baseline)
contrasts(df.hgf$diagnosis) = contr.sum(3)
contrasts(df.hgf$diagnosis)
##      [,1] [,2]
## ADHD    1    0
## BOTH    0    1
## COMP   -1   -1
# reorder the rows of contr.sum(2) so that FALSE = -1 and TRUE = +1
contrasts(df.hgf$adhd.meds.bin) = contr.sum(2)[c(2,1)]
contrasts(df.hgf$adhd.meds.bin)
##       [,1]
## FALSE   -1
## TRUE     1
# reorder rows so that BOTH carries the first contrast (see printed matrix)
contrasts(df.upd$diagnosis) = contr.sum(3)[c(2,1,3),]
contrasts(df.upd$diagnosis)
##      [,1] [,2]
## ADHD    0    1
## BOTH    1    0
## COMP   -1   -1
# sum coding for the change direction: pre2vol = +1, vol2post = -1
contrasts(df.upd$change) = contr.sum(2)
contrasts(df.upd$change)
##          [,1]
## pre2vol     1
## vol2post   -1
# sum coding for the hierarchy level: alpha2 = +1, alpha3 = -1
contrasts(df.upd$level) = contr.sum(2)
contrasts(df.upd$level)
##        [,1]
## alpha2    1
## alpha3   -1

2 H3c: second level tonic volatility

2.1 Model setup

# model formula: second-level tonic volatility predicted by diagnosis
f.om2 = brms::bf( om2 ~ diagnosis )

# set weakly informative priors
priors = c(
  prior(normal(0, 4),  class = Intercept),
  prior(normal(0, 0.50),  class = sigma),
  prior(normal(0, 0.50),  class = b)
)

# change Intercept based on empirical priors used in the HGF model:
# build the prior string directly from the empirical mean/sd instead of
# patching it with two regex substitutions — same result, less fragile
priors = priors %>%
  mutate(
    prior = if_else(
      class == "Intercept",
      paste0("normal(", mean(df.hgf$om2mu), ", ", mean(df.hgf$om2sa), ")"),
      prior
    )
  )

kable(priors)
prior class coef group resp dpar nlpar lb ub source
normal(-6.921, 8.7788) Intercept NA NA user
normal(0, 0.5) sigma NA NA user
normal(0, 0.5) b NA NA user

2.2 Posterior predictive checks

As the next step, we fit the model, check whether there are divergence or rhat issues, and then check whether the chains have converged.

# fit the final model (cached to disk via `file =`; refits only if absent)
# NOTE(review): `t` (thread count) is not defined in the visible chunks —
# presumably set in an earlier setup chunk; later models hard-code threading(8).
m.om2 = brm(f.om2, seed = 2288,
            df.hgf, prior = priors,
            iter = iter, warmup = warm,
            backend = "cmdstanr", threads = threading(t),
            file = file.path(brms_dir, "m_hgf_om2"),
            save_pars = save_pars(all = TRUE)
            )
# sampler diagnostics: divergences, tree-depth saturation, E-BFMI
rstan::check_hmc_diagnostics(m.om2$fit)
## 
## Divergences:
## 0 of 8000 iterations ended with a divergence.
## 
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
## 
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (count of parameters violating the cut-off)
sum(brms::rhat(m.om2) >= 1.01, na.rm = T)
## [1] 0
# check the trace plots of all regression coefficients
post.draws = as_draws_df(m.om2)
mcmc_trace(post.draws, regex_pars = "^b_",
           facet_args = list(ncol = 3)) +
  scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
  scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.

This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.

# draw simulated datasets from the fitted model for the checks below
post.pred = posterior_predict(m.om2, ndraws = nsim)

# overall density overlay: simulated outcome distributions vs. the observed one
p1 = pp_check(m.om2, ndraws = nsim) +
  theme_bw() +
  theme(legend.position = "none")

# per-diagnosis distributions of simulated means against the observed mean
p2 = ppc_stat_grouped(df.hgf$om2, post.pred, df.hgf$diagnosis) +
  theme_bw() +
  theme(legend.position = "none")

# stack both panels vertically and add an overall bold title
p = ggarrange(p1, p2, nrow = 2, ncol = 1, labels = "AUTO")
annotate_figure(
  p,
  top = text_grob("Posterior predictive checks", face = "bold", size = 14)
)

Similar to above, the data simulated from the model fits the real data reasonably well, although it does not fully reproduce the overall shape of the distribution.

2.3 Inferences

Now that we are convinced that we can trust our model, we have a look at its estimate and use the hypothesis function to assess our hypotheses and perform explorative tests.

# print a summary of the fitted model
summary(m.om2)
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: om2 ~ diagnosis 
##    Data: df.hgf (Number of observations: 66) 
##   Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
##          total post-warmup draws = 8000
## 
## Regression Coefficients:
##            Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept     -5.93      0.28    -6.48    -5.37 1.00     8326     6174
## diagnosis1     0.19      0.30    -0.39     0.78 1.00     7098     6186
## diagnosis2     0.15      0.30    -0.45     0.73 1.00     7272     5704
## 
## Further Distributional Parameters:
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     2.28      0.16     1.99     2.61 1.00     8008     5792
## 
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# get the estimates and compute group comparisons
# with sum contrasts ADHD = (1, 0), BOTH = (0, 1), COMP = (-1, -1), the cell
# means are the Intercept plus/minus the two contrast coefficients
df.m.om2 = post.draws %>% 
  select(starts_with("b_")) %>%
  mutate(
    ADHD   = b_Intercept + b_diagnosis1,
    BOTH   = b_Intercept + b_diagnosis2,
    COMP   = b_Intercept - b_diagnosis1 - b_diagnosis2,
    `h3c_ADHDvCOMP` = ADHD - COMP,
    `e1_BOTHvCOMP` = BOTH - COMP,
    `e2_ADHDvBOTH` = ADHD - BOTH,
    # trailing comma above is tolerated by mutate()'s dynamic dots
  )

# plot the posterior distributions of the three group means
# NOTE(review): `fill = c_light` is passed to ggplot() outside aes() and is
# silently ignored; if a fixed fill is intended it belongs in stat_halfeye()
df.m.om2 %>%
  select(ADHD, BOTH, COMP) %>%
  pivot_longer(cols = everything(), names_to = "coef", values_to = "estimate") %>%
  ggplot(aes(x = estimate, y = coef), fill = c_light) +
  geom_vline(xintercept = mean(df.m.om2$b_Intercept), linetype = 'dashed') +
  ggdist::stat_halfeye(alpha = 0.7) + ylab(NULL) + theme_bw() +
  theme(legend.position = "none")

# H3c: COMP != ADHD
# under the sum contrasts, ADHD - COMP = 2*diagnosis1 + diagnosis2; the
# one-sided hypothesis tests whether this difference is above 0
h3c = hypothesis(m.om2, "0 < 2*diagnosis1 + diagnosis2")
h3c$hypothesis
##                          Hypothesis   Estimate Est.Error  CI.Lower  CI.Upper
## 1 (0)-(2*diagnosis1+diagnosis2) < 0 -0.5340246 0.5734074 -1.483274 0.4013223
##   Evid.Ratio Post.Prob Star
## 1   4.602241    0.8215
# Explore BOTH
# BOTH - COMP = diagnosis1 + 2*diagnosis2 (alpha halved for the two
# exploratory one-sided tests)
e1  = hypothesis(m.om2, "0 < diagnosis1 + 2*diagnosis2", alpha = 0.025)
e1$hypothesis
##                          Hypothesis  Estimate Est.Error  CI.Lower CI.Upper
## 1 (0)-(diagnosis1+2*diagnosis2) < 0 -0.485513 0.5755258 -1.595381 0.632044
##   Evid.Ratio Post.Prob Star
## 1   4.076142     0.803
# ADHD - BOTH reduces to diagnosis1 - diagnosis2
e2  = hypothesis(m.om2, "diagnosis1 > diagnosis2", alpha= 0.025)
e2$hypothesis
##                      Hypothesis   Estimate Est.Error   CI.Lower CI.Upper
## 1 (diagnosis1)-(diagnosis2) > 0 0.04851162 0.4891632 -0.9173199  1.01821
##   Evid.Ratio Post.Prob Star
## 1   1.145923     0.534
# equivalence: fraction of the posterior of each contrast inside the ROPE
equivalence_test(df.m.om2 %>% select(starts_with("h") | starts_with("e")), 
                 range = rope_range(m.om2))
## # Test for Practical Equivalence
## 
##   ROPE: [-0.26 0.26]
## 
## Parameter     |        H0 | inside ROPE |       95% HDI
## -------------------------------------------------------
## h3c_ADHDvCOMP | Undecided |     24.05 % | [-0.59, 1.66]
## e1_BOTHvCOMP  | Undecided |     25.34 % | [-0.63, 1.60]
## e2_ADHDvBOTH  | Undecided |     42.87 % | [-0.92, 1.02]
# calculate effect sizes
# standardise each contrast by the total posterior SD (square root of the sum
# of all variance components), yielding a Cohen's-d-like value per draw
df.effect = post.draws %>%
  mutate(across(starts_with("sd")|starts_with("sigma"), ~.^2)) %>%
  mutate(
    sumvar = sqrt(rowSums(select(., starts_with("sd")|starts_with("sigma")))),
    h3c = (2*`b_diagnosis1` + `b_diagnosis2`) / sumvar,
    e1  = (`b_diagnosis1` + 2*`b_diagnosis2`) / sumvar,
    # -(-b1 + b2) = b1 - b2, i.e. the ADHD vs BOTH contrast
    e2  = -(-`b_diagnosis1` + `b_diagnosis2`) / sumvar
  )

# summarise the standardised effects with CIs and a verbal interpretation
kable(df.effect %>% select(starts_with("e")|starts_with("h")) %>%
        pivot_longer(cols = everything(), values_to = "estimate") %>%
        group_by(name) %>%
        summarise(
          ci.lo = lower_ci(estimate),
          mean  = mean(estimate),
          ci.hi = upper_ci(estimate),
          interpret = interpret_cohens_d(mean)
        ), digits = 3
)
name ci.lo mean ci.hi interpret
e1 -0.279 0.214 0.700 small
e2 -0.407 0.021 0.448 very small
h3c -0.252 0.235 0.726 small

estimate = -0.53 [-1.48, 0.4], posterior probability = 82.15%

3 Exploration: predicting ADHD diagnosis with HGF parameters

Predicting whether someone has ADHD or not based on the HGF parameters.

3.1 Model setup

# recode the order and scale the predictors
df.hgf = df.hgf %>%
  mutate(
    # group: 0 = COMP, 1 = unmedicated ADHD/BOTH, NA = medicated (excluded)
    group = case_when(
      diagnosis == "COMP" ~ 0,
      diagnosis != "COMP" & adhd.meds.bin == "FALSE" ~ 1,
      TRUE ~ NA_real_   # was `T ~ NA`: spelled-out TRUE and a typed NA
    ),
    # group.meds: 0 = unmedicated ADHD/BOTH, 1 = medicated, NA = COMP
    group.meds = if_else(adhd.meds.bin == "FALSE", 
                         if_else(diagnosis == "COMP", NA_real_, 0), 
                         1)
  ) %>% mutate(across(c(be1, be2, be3, ze, om2, om3), scale_this, .names = "s{.col}"))

# show how diagnosis maps onto the two binary outcomes defined above
kable(df.hgf %>% select(diagnosis, group, group.meds) %>% distinct(),
      caption = "Coding for the order in the Bernoulli models")
Coding for the order in the Bernoulli models
diagnosis group group.meds
BOTH NA 1
ADHD NA 1
ADHD 1 0
COMP 0 NA
BOTH 1 0
# model formula: predict group membership from the scaled HGF parameters
f = brms::bf( group ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3 )
f
## group ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3
# Bernoulli priors (on the logit scale)
priors.bern = c(
  prior(normal(0.50, 0.50),  class = Intercept), # roughly 1:1
  prior(normal(0,    1.00),  class = b)
)

3.2 Posterior predictive checks

# fit the final model (cached via `file =`; thread count hard-coded to 8 here,
# unlike the earlier models that use `threading(t)`)
m = brm(f,
        df.hgf, prior = priors.bern,
        family = bernoulli(link = "logit"),
        iter = iter, warmup = warm,
        backend = "cmdstanr", threads = threading(8),
        file = file.path(brms_dir, "m_hgf_bern_adhd"),
        seed = 4858
        )
# sampler diagnostics: divergences, tree-depth saturation, E-BFMI
rstan::check_hmc_diagnostics(m$fit)
## 
## Divergences:
## 0 of 8000 iterations ended with a divergence.
## 
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
## 
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (count of parameters violating the cut-off)
sum(brms::rhat(m) >= 1.01, na.rm = T)
## [1] 0
# check the trace plots of all regression coefficients
post.draws = as_draws_df(m)
mcmc_trace(post.draws, regex_pars = "^b_",
           facet_args = list(ncol = 4)) +
  scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
  scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.

This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.

# get posterior predictions for the check below
post.pred = posterior_predict(m, ndraws = nsim)

# bar plot of predicted vs. observed outcome frequencies (NA rows excluded)
p = ppc_bars(df.hgf[!is.na(df.hgf$group),]$group, post.pred) + 
  theme_bw() + theme(legend.position = "none")

annotate_figure(p, top = text_grob("Posterior predictive checks", 
                                   face = "bold", size = 14))

The overall simulated data fits reasonably well. Now that we are convinced that we can trust our model, we have a look at its estimates.

3.3 Inferences

# print a summary of the Bernoulli model predicting ADHD group membership
summary(m)
##  Family: bernoulli 
##   Links: mu = logit 
## Formula: group ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3 
##    Data: df.hgf (Number of observations: 41) 
##   Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
##          total post-warmup draws = 8000
## 
## Regression Coefficients:
##           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept     0.06      0.29    -0.52     0.63 1.00    11663     5569
## sbe1          0.18      0.36    -0.53     0.91 1.00    11572     5518
## sbe2         -0.09      0.37    -0.81     0.65 1.00    10220     6430
## sbe3         -0.19      0.36    -0.89     0.51 1.00    12722     6120
## sze           0.33      0.42    -0.48     1.17 1.00    10045     5970
## som2          0.66      0.36    -0.01     1.40 1.00    11259     6381
## som3         -0.04      0.44    -0.89     0.82 1.00     9145     6332
## 
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# visualise the posterior of each predictor's coefficient; distributions whose
# 95% interval excludes zero are highlighted as "credible"
post.draws %>% 
  select(starts_with("b_") & !starts_with("b_Int")) %>%
  pivot_longer(
    cols = starts_with("b_"),
    names_to = "coef",
    values_to = "estimate"
  ) %>%
  mutate(
    # strip the "b_" prefix and order coefficients by their estimates
    coef = substr(coef, 3, nchar(coef)),
    coef = fct_reorder(coef, desc(estimate))
  ) %>%
  group_by(coef) %>%
  mutate(
    # a coefficient counts as credible when its posterior mean and the
    # nearer 95% quantile lie on the same side of zero
    cred = case_when(
      (mean(estimate) < 0 & quantile(estimate, probs = 0.975) < 0) |
        (mean(estimate) > 0 & quantile(estimate, probs = 0.025) > 0) ~ "credible",
      T ~ "not credible"
    )
  ) %>%
  ungroup() %>%
  ggplot(aes(x = estimate, y = coef, fill = cred)) +
  geom_vline(xintercept = 0, linetype = 'dashed') +
  ggdist::stat_halfeye(alpha = 0.7) +
  ylab(NULL) +
  scale_fill_manual(values = c("credible" = c_dark, "not credible" = c_light)) +
  theme_bw() +
  theme(legend.position = "bottom", legend.direction = "horizontal")

# exploratory one-sided test: is the som2 coefficient above 0?
e1 = hypothesis(m, "0 > -som2", alpha = 0.025)
e1$hypothesis
##        Hypothesis  Estimate Est.Error    CI.Lower CI.Upper Evid.Ratio Post.Prob
## 1 (0)-(-som2) > 0 0.6647926 0.3619556 -0.00997184 1.401287   36.91469  0.973625
##   Star
## 1
# equivalence: fraction of each coefficient's posterior inside the ROPE
equivalence_test(m)
## # Test for Practical Equivalence
## 
##   ROPE: [-0.18 0.18]
## 
## Parameter |        H0 | inside ROPE |           95% HDI
## -------------------------------------------------------
## Intercept | Undecided |     47.91 % |     [-0.52, 0.63]
## sbe1      | Undecided |     36.29 % |     [-0.53, 0.91]
## sbe2      | Undecided |     37.88 % |     [-0.81, 0.65]
## sbe3      | Undecided |     36.39 % |     [-0.89, 0.51]
## sze       | Undecided |     27.62 % |     [-0.48, 1.17]
## som2      | Undecided |      6.51 % | [-9.97e-03, 1.40]
## som3      | Undecided |     33.36 % |     [-0.89, 0.82]
# effect sizes: coefficients are on standardised predictors, so they are
# summarised directly with CIs and a verbal interpretation
kable(post.draws %>% select(starts_with("b_s")) %>%
        pivot_longer(cols = everything(), values_to = "estimate") %>%
        group_by(name) %>%
        summarise(
          ci.lo = lower_ci(estimate),
          mean  = mean(estimate),
          ci.hi = upper_ci(estimate),
          interpret = interpret_cohens_d(mean)
        ), digits = 3
)
name ci.lo mean ci.hi interpret
b_sbe1 -0.533 0.182 0.914 very small
b_sbe2 -0.811 -0.093 0.650 very small
b_sbe3 -0.891 -0.186 0.507 very small
b_som2 -0.010 0.665 1.401 medium
b_som3 -0.892 -0.037 0.821 very small
b_sze -0.477 0.329 1.174 small

4 Exploration: predicting ADHD medication with HGF parameters

Predicting whether someone with ADHD is taking medication or not based on the HGF parameters.

4.1 Model setup

# model formula: predict medication status from the scaled HGF parameters
# (reuses priors.bern from the previous Bernoulli model)
f = brms::bf( group.meds ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3 )
f
## group.meds ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3

4.2 Posterior predictive checks

# fit the final model (cached via `file =`; same structure as the ADHD model)
m = brm(f,
        df.hgf, prior = priors.bern,
        family = bernoulli(link = "logit"),
        iter = iter, warmup = warm,
        backend = "cmdstanr", threads = threading(8),
        file = file.path(brms_dir, "m_hgf_bern_meds"),
        seed = 8428
        )
# sampler diagnostics: divergences, tree-depth saturation, E-BFMI
rstan::check_hmc_diagnostics(m$fit)
## 
## Divergences:
## 0 of 8000 iterations ended with a divergence.
## 
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
## 
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (count of parameters violating the cut-off)
sum(brms::rhat(m) >= 1.01, na.rm = T)
## [1] 0
# check the trace plots of all regression coefficients
post.draws = as_draws_df(m)
mcmc_trace(post.draws, regex_pars = "^b_",
           facet_args = list(ncol = 4)) +
  scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
  scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.

This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.

# get posterior predictions for the check below
post.pred = posterior_predict(m, ndraws = nsim)

# bar plot of predicted vs. observed outcome frequencies (NA rows excluded)
p = ppc_bars(df.hgf[!is.na(df.hgf$group.meds),]$group.meds, post.pred) + 
  theme_bw() + theme(legend.position = "none")

annotate_figure(p, top = text_grob("Posterior predictive checks", 
                                   face = "bold", size = 14))

The overall simulated data fits reasonably well. Now that we are convinced that we can trust our model, we have a look at its estimates.

4.3 Inferences

# print a summary of the Bernoulli model predicting medication status
summary(m)
##  Family: bernoulli 
##   Links: mu = logit 
## Formula: group.meds ~ sbe1 + sbe2 + sbe3 + sze + som2 + som3 
##    Data: df.hgf (Number of observations: 44) 
##   Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
##          total post-warmup draws = 8000
## 
## Regression Coefficients:
##           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept     0.39      0.29    -0.17     0.96 1.00    13857     5726
## sbe1          0.19      0.35    -0.49     0.90 1.00    11114     6722
## sbe2          0.28      0.36    -0.43     1.00 1.00     9999     6075
## sbe3          0.04      0.33    -0.60     0.69 1.00    11532     6472
## sze           0.20      0.35    -0.46     0.91 1.00    11852     6181
## som2         -0.52      0.33    -1.20     0.11 1.00    10907     6430
## som3         -0.18      0.38    -0.94     0.56 1.00    11295     5957
## 
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# visualise the posterior of each predictor's coefficient; distributions whose
# 95% interval excludes zero are highlighted as "credible"
post.draws %>% 
  select(starts_with("b_") & !starts_with("b_Int")) %>%
  pivot_longer(
    cols = starts_with("b_"),
    names_to = "coef",
    values_to = "estimate"
  ) %>%
  mutate(
    # strip the "b_" prefix and order coefficients by their estimates
    coef = substr(coef, 3, nchar(coef)),
    coef = fct_reorder(coef, desc(estimate))
  ) %>%
  group_by(coef) %>%
  mutate(
    # a coefficient counts as credible when its posterior mean and the
    # nearer 95% quantile lie on the same side of zero
    cred = case_when(
      (mean(estimate) < 0 & quantile(estimate, probs = 0.975) < 0) |
        (mean(estimate) > 0 & quantile(estimate, probs = 0.025) > 0) ~ "credible",
      T ~ "not credible"
    )
  ) %>%
  ungroup() %>%
  ggplot(aes(x = estimate, y = coef, fill = cred)) +
  geom_vline(xintercept = 0, linetype = 'dashed') +
  ggdist::stat_halfeye(alpha = 0.7) +
  ylab(NULL) +
  scale_fill_manual(values = c("credible" = c_dark, "not credible" = c_light)) +
  theme_bw() +
  theme(legend.position = "bottom", legend.direction = "horizontal")

# exploratory one-sided test: is the som2 coefficient below 0?
e1 = hypothesis(m, "0 > som2", alpha = 0.025)
e1$hypothesis
##       Hypothesis Estimate Est.Error   CI.Lower CI.Upper Evid.Ratio Post.Prob
## 1 (0)-(som2) > 0 0.522949 0.3330175 -0.1088378 1.203256   18.95012  0.949875
##   Star
## 1
# equivalence: fraction of each coefficient's posterior inside the ROPE
equivalence_test(m)
## # Test for Practical Equivalence
## 
##   ROPE: [-0.18 0.18]
## 
## Parameter |        H0 | inside ROPE |       95% HDI
## ---------------------------------------------------
## Intercept | Undecided |     22.33 % | [-0.17, 0.96]
## sbe1      | Undecided |     36.91 % | [-0.49, 0.90]
## sbe2      | Undecided |     31.49 % | [-0.43, 1.00]
## sbe3      | Undecided |     44.79 % | [-0.60, 0.69]
## sze       | Undecided |     37.37 % | [-0.46, 0.91]
## som2      | Undecided |     13.16 % | [-1.20, 0.11]
## som3      | Undecided |     35.99 % | [-0.94, 0.56]
# effect sizes: coefficients are on standardised predictors, so they are
# summarised directly with CIs and a verbal interpretation
kable(post.draws %>% select(starts_with("b_s")) %>%
        pivot_longer(cols = everything(), values_to = "estimate") %>%
        group_by(name) %>%
        summarise(
          ci.lo = lower_ci(estimate),
          mean  = mean(estimate),
          ci.hi = upper_ci(estimate),
          interpret = interpret_cohens_d(mean)
        ), digits = 3
)
name ci.lo mean ci.hi interpret
b_sbe1 -0.494 0.185 0.899 very small
b_sbe2 -0.430 0.276 0.998 small
b_sbe3 -0.604 0.039 0.691 very small
b_som2 -1.203 -0.523 0.109 medium
b_som3 -0.938 -0.180 0.561 very small
b_sze -0.461 0.202 0.907 small

5 Plots for HGF parameters

# raincloud plots of the participant-specific HGF parameters per diagnosis
# NOTE(review): if_else() mixes a character value with the factor `diagnosis`;
# presumably this coerces to character as intended — confirm with the
# installed dplyr version
p = df.hgf %>%
  mutate(diagnosis = if_else(diagnosis == "BOTH", "ADHD+ASD", diagnosis)) %>%
  select(subID, diagnosis, be1, be2, be3, ze, om2, om3) %>% #
  pivot_longer(cols = c(be1, be2, be3, ze, om2, om3), 
               names_to = "parameter") %>%
  mutate(
    # map the parameter codes to readable labels and fix the panel order
    parameter = factor(case_match(parameter,
                           "be1" ~ "stimulus surprise",
                           "be2" ~ "precision-weighted PE",
                           "be3" ~ "phasic volatility",
                           "ze"  ~ "Sigma (decision noise)",
                           "om2" ~ "cue-outcome tonic volatility",
                           "om3" ~ "environmental tonic volatility"
                           ), levels = c("cue-outcome tonic volatility", 
                                         "environmental tonic volatility", 
                                         "stimulus surprise", 
                                         "precision-weighted PE", 
                                         "phasic volatility", 
                                         "Sigma (decision noise)"))
  ) %>%
  ggplot(aes(x = 1, y = value, fill = diagnosis, colour = diagnosis)) + #
  geom_rain(rain.side = 'r',
boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
violin.args = list(color = "black", outlier.shape = NA, alpha = .8),
boxplot.args.pos = list(
  position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
),
point.args = list(show.legend = FALSE, alpha = .5),
violin.args.pos = list(
  width = 0.6, position = position_nudge(x = 0.16)),
point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))) +
  scale_fill_manual(values = col.grp) +
  scale_color_manual(values = col.grp) +
  facet_wrap(. ~ parameter, scales = "free", ncol = 3) +
  theme_bw() + 
  theme(legend.position = "bottom", plot.title = element_blank(), 
        axis.title.y = element_blank(), axis.title.x = element_blank(),
        text = element_text(size = 13), axis.text.x=element_blank(), 
        axis.ticks.x=element_blank(), legend.direction = "horizontal",
        legend.title = element_blank(),
        legend.margin=margin(0,0,0,0),
        legend.box.margin=margin(-5,0,0,0))

# add an overall title above the arranged panels
p.a = annotate_figure(p, top = text_grob("Participant-specific HGF parameters", 
                                   face = "bold", size = 14))

# export the figure for the manuscript
ggsave("plots/FigHGF.svg", plot = p.a, units = "cm", width = 27, height = 13.5)

# include medication: same raincloud layout, but groups are split further
# into medicated vs. unmedicated ADHD/ADHD+ASD participants
df.hgf %>%
  mutate(diagnosis = if_else(diagnosis == "BOTH", "ADHD+ASD", diagnosis),
         adhd.meds.bin = case_when(adhd.meds.bin == "TRUE" ~ "medicated", 
                                   T ~ ""),
         # e.g. "ADHDmedicated" vs "ADHD"; used for fill/colour below
         group = paste0(diagnosis, adhd.meds.bin)) %>%
  select(subID, diagnosis, group, be1, be2, be3, ze, om2, om3) %>% #
  pivot_longer(cols = c(be1, be2, be3, ze, om2, om3), 
               names_to = "parameter") %>%
  mutate(
    # map the parameter codes to readable labels and fix the panel order
    parameter = factor(case_match(parameter,
                           "be1" ~ "stimulus surprise",
                           "be2" ~ "precision-weighted PE",
                           "be3" ~ "phasic volatility",
                           "ze"  ~ "Sigma (decision noise)",
                           "om2" ~ "2nd tonic volatility",
                           "om3" ~ "3rd tonic volatility"
                           ), levels = c("2nd tonic volatility", 
                                         "3rd tonic volatility", 
                                         "stimulus surprise", 
                                         "precision-weighted PE", 
                                         "phasic volatility", 
                                         "Sigma (decision noise)"))
  ) %>%
  ggplot(aes(x = diagnosis, y = value, fill = group, colour = group)) + #
  geom_rain(rain.side = 'r',
boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
violin.args = list(color = "black", outlier.shape = NA, alpha = .8),
boxplot.args.pos = list(
  position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
),
point.args = list(show.legend = FALSE, alpha = .5),
violin.args.pos = list(
  width = 0.6, position = position_nudge(x = 0.16)),
point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))) +
  #scale_fill_manual(values = col.grp) +
  #scale_color_manual(values = col.grp) +
  facet_wrap(. ~ parameter, scales = "free", ncol = 3) +
  theme_bw() + 
  theme(legend.position = "bottom", plot.title = element_blank(), 
        axis.title.y = element_blank(), axis.title.x = element_blank(),
        text = element_text(size = 13), axis.text.x=element_blank(), 
        axis.ticks.x=element_blank(), legend.direction = "horizontal",
        legend.title = element_blank(),
        legend.margin=margin(0,0,0,0),
        legend.box.margin=margin(-5,0,0,0))

6 Learning rate update - volatile to stable

6.1 Model setup

# model formula: absolute alpha change by diagnosis x level x change,
# with by-subject random slopes for level and change
f.alpha = brms::bf( value ~ diagnosis * level * change + (level + change | subID) )

# set weakly informative priors taking Lawson 2017 into consideration
# (intercept prior is on the log scale; the model below is lognormal)
priors = c(
  prior(normal(-5, 2),    class = Intercept),
  prior(normal(0.5, 0.5), class = sigma),
  prior(normal(0.5, 0.5), class = sd),
  prior(lkj(2),           class = cor),
  prior(normal(0,   1.0),   class = b) # probably big difference between levels
)

6.2 Posterior predictive checks

As the next step, we fit the model, check whether there are divergence or rhat issues, and then check whether the chains have converged.

# fit the final model (lognormal: the absolute alpha changes are positive)
# NOTE(review): `threading(t)` relies on `t` being defined in an earlier
# setup chunk — confirm; later Bernoulli models hard-code threading(8)
m.alpha = brm(f.alpha, family = lognormal,
            df.upd, prior = priors, seed = 6688,
            iter = iter, warmup = warm,
            backend = "cmdstanr", threads = threading(t),
            file = file.path(brms_dir, "m_hgf_alpha"),
            save_pars = save_pars(all = TRUE)
            )
# sampler diagnostics: divergences, tree-depth saturation, E-BFMI
rstan::check_hmc_diagnostics(m.alpha$fit)
## 
## Divergences:
## 0 of 8000 iterations ended with a divergence.
## 
## Tree depth:
## 0 of 8000 iterations saturated the maximum tree depth of 10.
## 
## Energy:
## E-BFMI indicated no pathological behavior.
# check that rhats are below 1.01 (count of parameters violating the cut-off)
sum(brms::rhat(m.alpha) >= 1.01, na.rm = T)
## [1] 0
# check the trace plots of all regression coefficients
post.draws = as_draws_df(m.alpha)
mcmc_trace(post.draws, regex_pars = "^b_",
           facet_args = list(ncol = 3)) +
  scale_x_continuous(breaks=scales::pretty_breaks(n = 3)) +
  scale_y_continuous(breaks=scales::pretty_breaks(n = 3))
## Scale for x is already present.
## Adding another scale for x, which will replace the existing scale.

This model has no pathological behaviour with E-BFMI, no divergent sample and no rhats that are higher or equal to 1.01. Therefore, we go ahead and perform our posterior predictive checks.

# get posterior predictions for the checks below
post.pred = posterior_predict(m.alpha, ndraws = nsim)

# overall density overlay (x-axis truncated to the bulk of the distribution)
p1 = pp_check(m.alpha, ndraws = nsim) + 
  theme_bw() + theme(legend.position = "none") + xlim(0, 0.10)

# distributions of simulated means vs. the observed mean per factor level
p2 = ppc_stat_grouped(df.upd$value, post.pred, df.upd$diagnosis) + 
  theme_bw() + theme(legend.position = "none")
p3 = ppc_stat_grouped(df.upd$value, post.pred, df.upd$level) + 
  theme_bw() + theme(legend.position = "none")
p4 = ppc_stat_grouped(df.upd$value, post.pred, df.upd$change) + 
  theme_bw() + theme(legend.position = "none")

# stack all four panels and add an overall title
p = ggarrange(p1, p2, p3, p4, ncol = 1)
annotate_figure(p, top = text_grob("Posterior predictive checks", 
                                   face = "bold", size = 14))

This model fits the data well enough.

6.3 Inferences

Now that we are convinced that we can trust our model, we have a look at its estimate and use the hypothesis function to assess our hypotheses and perform explorative tests.

# print a summary of the learning-rate-change model
summary(m.alpha)
##  Family: lognormal 
##   Links: mu = identity; sigma = identity 
## Formula: value ~ diagnosis * level * change + (level + change | subID) 
##    Data: df.upd (Number of observations: 264) 
##   Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
##          total post-warmup draws = 8000
## 
## Multilevel Hyperparameters:
## ~subID (Number of levels: 66) 
##                        Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(Intercept)              1.11      0.12     0.90     1.35 1.00     2373
## sd(level1)                 0.82      0.10     0.63     1.03 1.00     2986
## sd(change1)                0.21      0.08     0.05     0.36 1.00     2793
## cor(Intercept,level1)      0.41      0.13     0.15     0.65 1.00     2399
## cor(Intercept,change1)     0.60      0.23     0.02     0.92 1.00     5210
## cor(level1,change1)        0.62      0.23     0.03     0.92 1.00     5018
##                        Tail_ESS
## sd(Intercept)              4179
## sd(level1)                 5020
## sd(change1)                1653
## cor(Intercept,level1)      4229
## cor(Intercept,change1)     3847
## cor(level1,change1)        3687
## 
## Regression Coefficients:
##                           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## Intercept                    -4.95      0.15    -5.24    -4.66 1.00     1627
## diagnosis1                    0.18      0.21    -0.22     0.59 1.00     1712
## diagnosis2                   -0.07      0.21    -0.48     0.33 1.00     1640
## level1                        0.33      0.12     0.10     0.56 1.00     2420
## change1                       0.77      0.07     0.64     0.90 1.00     6477
## diagnosis1:level1             0.06      0.17    -0.26     0.39 1.00     2513
## diagnosis2:level1             0.15      0.16    -0.17     0.47 1.00     2632
## diagnosis1:change1            0.05      0.10    -0.14     0.24 1.00     5609
## diagnosis2:change1            0.04      0.10    -0.15     0.23 1.00     6322
## level1:change1               -0.10      0.06    -0.22     0.03 1.00    12332
## diagnosis1:level1:change1     0.06      0.09    -0.11     0.24 1.00     7030
## diagnosis2:level1:change1    -0.07      0.09    -0.24     0.10 1.00     7296
##                           Tail_ESS
## Intercept                     2760
## diagnosis1                    2534
## diagnosis2                    2841
## level1                        3706
## change1                       5573
## diagnosis1:level1             4210
## diagnosis2:level1             4199
## diagnosis1:change1            5829
## diagnosis2:change1            5894
## level1:change1                5555
## diagnosis1:level1:change1     6185
## diagnosis2:level1:change1     5620
## 
## Further Distributional Parameters:
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     1.01      0.07     0.89     1.15 1.00     2726     4036
## 
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# get the estimates and compute group comparisons
df.m.alpha = post.draws %>% 
  select(starts_with("b_"))

# plot the posterior distributions
df.m.alpha %>%
  select(starts_with("b_")) %>%
  pivot_longer(cols = starts_with("b_"), names_to = "coef", values_to = "estimate") %>%
  subset(!startsWith(coef, "b_Int")) %>%
  mutate(
    coef = substr(coef, 3, nchar(coef)),
    coef = str_replace_all(coef, ":", " x "),
    coef = str_replace_all(coef, "diagnosis1", "ADHD"),
    coef = str_replace_all(coef, "diagnosis2", "BOTH"),
    coef = str_replace_all(coef, "level1", "alpha2"),
    coef = str_replace_all(coef, "change1", "pre2vol"),
    coef = fct_reorder(coef, desc(estimate))
  ) %>% 
  group_by(coef) %>%
  mutate(
    cred = case_when(
      (mean(estimate) < 0 & quantile(estimate, probs = 0.975) < 0) |
        (mean(estimate) > 0 & quantile(estimate, probs = 0.025) > 0) ~ "credible",
      T ~ "not credible"
    )
  ) %>% ungroup() %>%
  ggplot(aes(x = estimate, y = coef, fill = cred)) +
  geom_vline(xintercept = 0, linetype = 'dashed') +
  ggdist::stat_halfeye(alpha = 0.7) + ylab(NULL) + theme_bw() +
  scale_fill_manual(values = c(c_dark, c_light)) + theme(legend.position = "none")

# get the design matrix to figure out how to set the contrasts
# NOTE: cbind() keeps duplicate column names, so the model-matrix columns
# (e.g. "(Intercept)", "diagnosis1") sit next to the original factor
# columns; subID and the response are dropped before distinct() collapses
# the rows to one per design cell
df.des = cbind(df.upd, 
               model.matrix(~ diagnosis * level * change, data = df.upd)) %>%
  ungroup() %>%
  select(-subID, -value) %>% distinct()

# H4c ADHD != COMP
# directional test of the diagnosis main effect; the weights
# 2*diagnosis1 + diagnosis2 encode the ADHD-vs-COMP contrast under the
# model's diagnosis coding (see the design-matrix check below)
h4c = hypothesis(m.alpha, "0 < 2*diagnosis1 + diagnosis2", alpha = 0.025)
h4c$hypothesis
##                          Hypothesis   Estimate Est.Error  CI.Lower  CI.Upper
## 1 (0)-(2*diagnosis1+diagnosis2) < 0 -0.2903084 0.3674689 -1.013658 0.4192193
##   Evid.Ratio Post.Prob Star
## 1   3.692082  0.786875
# Exploration: alpha3 ADHD != COMP
# average the design-matrix rows per group, then print the
# COMP - ADHD difference to derive the contrast weights
df.des %>%
  filter(level == "alpha3", diagnosis != "BOTH") %>%
  group_by(diagnosis) %>%
  summarise(across(where(is.numeric), mean)) %>%
  arrange(diagnosis) %>%
  select(where(is.numeric)) %>%
  map_df(diff) %>%
  t() # COMP - ADHD
##                                  [,1]
## EDT                       -0.06595082
## (Intercept)                0.00000000
## diagnosis1                -1.00000000
## diagnosis2                -2.00000000
## level1                     0.00000000
## change1                    0.00000000
## diagnosis1:level1          1.00000000
## diagnosis2:level1          2.00000000
## diagnosis1:change1         0.00000000
## diagnosis2:change1         0.00000000
## level1:change1             0.00000000
## diagnosis1:level1:change1  0.00000000
## diagnosis2:level1:change1  0.00000000
# directional test of the alpha3 ADHD-vs-COMP contrast; the weights match
# the COMP - ADHD design-matrix difference printed above
e1 = hypothesis(m.alpha, "0 > -2*diagnosis1 - diagnosis2 + 
                               2*diagnosis1:level1 + diagnosis2:level1", alpha = 0.025)
e1$hypothesis
##                                                                 Hypothesis
## 1 (0)-(-2*diagnosis1-diagnosis2+2*diagnosis1:level1+diagnosis2:level1) > 0
##     Estimate Est.Error   CI.Lower  CI.Upper Evid.Ratio Post.Prob Star
## 1 0.00919434 0.3906923 -0.7748123 0.7637141   1.060793   0.51475
# H4c: alpha2 ADHD != COMP
# same design-matrix averaging as above, now for the alpha2 level
df.des %>%
  filter(level == "alpha2", diagnosis != "BOTH") %>%
  group_by(diagnosis) %>%
  summarise(across(where(is.numeric), mean)) %>%
  arrange(diagnosis) %>%
  select(where(is.numeric)) %>%
  map_df(diff) %>%
  t() # COMP - ADHD
##                                  [,1]
## EDT                       -0.06595082
## (Intercept)                0.00000000
## diagnosis1                -1.00000000
## diagnosis2                -2.00000000
## level1                     0.00000000
## change1                    0.00000000
## diagnosis1:level1         -1.00000000
## diagnosis2:level1         -2.00000000
## diagnosis1:change1         0.00000000
## diagnosis2:change1         0.00000000
## level1:change1             0.00000000
## diagnosis1:level1:change1  0.00000000
## diagnosis2:level1:change1  0.00000000
# directional test of the alpha2 ADHD-vs-COMP contrast; the weights match
# the COMP - ADHD design-matrix difference printed above
e2 = hypothesis(m.alpha, "0 > -(2*diagnosis1 + diagnosis2 + 
                                2*diagnosis1:level1 + diagnosis2:level1)", alpha = 0.025)
e2$hypothesis
##                                                                   Hypothesis
## 1 (0)-(-(2*diagnosis1+diagnosis2+2*diagnosis1:level1+diagnosis2:level1)) > 0
##    Estimate Est.Error   CI.Lower CI.Upper Evid.Ratio Post.Prob Star
## 1 0.5714225 0.5394324 -0.4726952  1.62411   6.054674   0.85825
# Explore BOTH
# directional test for the BOTH (ADHD+ASD) group vs. COMP; the weights
# mirror e1 with the roles of diagnosis1 and diagnosis2 swapped

e3 = hypothesis(m.alpha, "0 < -(2*diagnosis2 + diagnosis1) + 
                                2*diagnosis2:level1 + diagnosis1:level1", alpha = 0.025)
e3$hypothesis
##                                                                   Hypothesis
## 1 (0)-(-(2*diagnosis2+diagnosis1)+2*diagnosis2:level1+diagnosis1:level1) < 0
##    Estimate Est.Error CI.Lower  CI.Upper Evid.Ratio Post.Prob Star
## 1 -0.323644 0.3870926 -1.08834 0.4232842   3.875076  0.794875
# directional test for the BOTH group vs. COMP on the other level; the
# weights mirror e2 with diagnosis1 and diagnosis2 swapped
e4 = hypothesis(m.alpha, "0 > -(2*diagnosis2 + diagnosis1 + 
                                2*diagnosis2:level1 + diagnosis1:level1)", alpha = 0.025)
e4$hypothesis
##                                                                   Hypothesis
## 1 (0)-(-(2*diagnosis2+diagnosis1+2*diagnosis2:level1+diagnosis1:level1)) > 0
##    Estimate Est.Error   CI.Lower CI.Upper Evid.Ratio Post.Prob Star
## 1 0.4151504 0.5398817 -0.6248419 1.475844   3.535147    0.7795
# calculate effect sizes
# square all sd_*/sigma draws into variances, sum them to a total
# variance, and standardise the H4c contrast by the resulting total SD
df.effect = post.draws %>%
  mutate(across(starts_with("sd")|starts_with("sigma"), ~.^2)) %>%
  mutate(
    # "." is the magrittr placeholder for this pipe stage's input, i.e.
    # the data AFTER the squaring step above, so rowSums() runs over the
    # squared sd/sigma columns
    sumvar = sqrt(rowSums(select(., starts_with("sd")|starts_with("sigma")))),
    h4c = (2*`b_diagnosis1` + `b_diagnosis2`) / sumvar
  )

# summarise the standardised effect(s): CI limits, mean and verbal label
# NOTE(review): starts_with("e")|starts_with("h") is meant to pick up the
# e*/h* effect columns (only h4c appears in the output) -- confirm that no
# other draw columns happen to match these prefixes
kable(df.effect %>% select(starts_with("e")|starts_with("h")) %>%
        pivot_longer(cols = everything(), values_to = "estimate") %>%
        group_by(name) %>%
        summarise(
          ci.lo = lower_ci(estimate),
          mean  = mean(estimate),
          ci.hi = upper_ci(estimate),
          interpret = interpret_cohens_d(mean)
        ), digits = 3
)
name ci.lo mean ci.hi interpret
h4c -0.244 0.168 0.59 very small

h4c ADHD vs. COMP: estimate = -0.29 [-1.01, 0.42], posterior probability = 78.69%

6.4 Check the influence of outliers

# rank transform the values to reduce the influence of extreme observations
df.upd = df.upd %>% ungroup() %>%
  mutate(rvalue = rank(value))

# fit (or load a cached) Bayesian ANOVA on the rank-transformed values
# NOTE(review): subID is not in the formula, so subjects are not modelled
# as a random factor here -- confirm this is intended
if (!file.exists(file.path(brms_dir, "aov_alpha.rds"))) {
  aov = anovaBF(rvalue ~ diagnosis * level * change, data = df.upd)
  # cache the fit; without this saveRDS the cached file never existed
  # and the expensive model was refit on every knit
  saveRDS(aov, file.path(brms_dir, "aov_alpha.rds"))
} else {
  aov = readRDS(file.path(brms_dir, "aov_alpha.rds"))
}

# rank the candidate models by their (log) Bayes factor and interpret the
# step-wise difference to the next-best model
kable(aov@bayesFactor %>% arrange(desc(bf)) %>%
  select(bf) %>% mutate(bf.diff = abs(lead(bf)-bf),
                        bf.int  = interpret_bf(bf.diff, log = TRUE)), digits = 3)
bf bf.diff bf.int
level + change 24.839 0.728 anecdotal evidence in favour of
level + change + level:change 24.110 1.606 moderate evidence in favour of
change 22.504 0.484 anecdotal evidence in favour of
diagnosis + level + change 22.020 0.614 anecdotal evidence in favour of
diagnosis + level + change + level:change 21.407 0.825 anecdotal evidence in favour of
diagnosis + level + diagnosis:level + change 20.582 0.729 anecdotal evidence in favour of
diagnosis + level + diagnosis:level + change + level:change 19.852 0.135 anecdotal evidence in favour of
diagnosis + change 19.717 0.123 anecdotal evidence in favour of
diagnosis + level + change + diagnosis:change 19.595 0.648 anecdotal evidence in favour of
diagnosis + level + change + diagnosis:change + level:change 18.946 0.837 anecdotal evidence in favour of
diagnosis + level + diagnosis:level + change + diagnosis:change 18.109 0.646 anecdotal evidence in favour of
diagnosis + level + diagnosis:level + change + diagnosis:change + level:change 17.464 0.148 anecdotal evidence in favour of
diagnosis + change + diagnosis:change 17.315 2.001 moderate evidence in favour of
diagnosis + level + diagnosis:level + change + diagnosis:change + level:change + diagnosis:level:change 15.314 13.743 extreme evidence in favour of
level 1.571 2.836 strong evidence in favour of
diagnosis + level -1.265 1.575 moderate evidence in favour of
diagnosis -2.841 0.168 anecdotal evidence in favour of
diagnosis + level + diagnosis:level -3.008 NA

7 Plots for learning rate updates

# rain cloud plot of the learning-rate updates per diagnosis group,
# faceted by learning-rate level and phase change
df.upd %>%
  ggplot(aes(1, value, fill = diagnosis, colour = diagnosis)) +
  geom_rain(
    rain.side = 'r',
    boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
    # outlier.shape is a boxplot-only parameter and was removed from
    # violin.args: geom_violin ignores it with a warning
    violin.args = list(color = "black", alpha = .8),
    boxplot.args.pos = list(
      position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
    ),
    point.args = list(show.legend = FALSE, alpha = .5),
    violin.args.pos = list(
      width = 0.6, position = position_nudge(x = 0.16)
    ),
    point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))
  ) +
  scale_fill_manual(values = col.grp) +
  scale_color_manual(values = col.grp) +
  facet_wrap(level ~ change, scales = "free") +
  labs(title = "Learning rate updates", x = "", y = "") +
  theme_bw() + 
  theme(legend.position = "bottom", plot.title = element_text(hjust = 0.5), 
        legend.direction = "horizontal", text = element_text(size = 15),
        axis.text.x = element_blank(), axis.ticks.x = element_blank())

# Excluding the outliers (updates >= 0.4) to check whether the pattern
# in the previous plot is driven by them
df.upd %>%
  filter(value < 0.4) %>%
  ggplot(aes(1, value, fill = diagnosis, colour = diagnosis)) +
  geom_rain(
    rain.side = 'r',
    boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
    # outlier.shape is a boxplot-only parameter and was removed from
    # violin.args: geom_violin ignores it with a warning
    violin.args = list(color = "black", alpha = .8),
    boxplot.args.pos = list(
      position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
    ),
    point.args = list(show.legend = FALSE, alpha = .5),
    violin.args.pos = list(
      width = 0.6, position = position_nudge(x = 0.16)
    ),
    point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))
  ) +
  scale_fill_manual(values = col.grp) +
  scale_color_manual(values = col.grp) +
  facet_wrap(level ~ change, scales = "free") +
  labs(title = "Learning rate updates", x = "", y = "") +
  theme_bw() + 
  theme(legend.position = "bottom", plot.title = element_text(hjust = 0.5), 
        legend.direction = "horizontal", text = element_text(size = 15),
        axis.text.x = element_blank(), axis.ticks.x = element_blank())

# count how many observations were excluded, per diagnosis group
df.upd %>% filter(value >= 0.4) %>% group_by(diagnosis) %>% count()
## # A tibble: 3 × 2
## # Groups:   diagnosis [3]
##   diagnosis     n
##   <fct>     <int>
## 1 ADHD          4
## 2 BOTH          3
## 3 COMP          1
# including medication: relabel the BOTH group, tag medicated participants
# and colour by the combined diagnosis x medication grouping
df.upd %>%   
  merge(., df.hgf %>% select(subID, adhd.meds.bin)) %>%
  mutate(
    diagnosis = if_else(diagnosis == "BOTH", "ADHD+ASD", diagnosis),
    # NOTE(review): comparing against the string "TRUE" assumes
    # adhd.meds.bin is stored as character/factor -- confirm upstream
    adhd.meds.bin = case_when(adhd.meds.bin == "TRUE" ~ "medicated", 
                              TRUE ~ ""),
    group = paste0(diagnosis, adhd.meds.bin)
  ) %>%
  ggplot(aes(diagnosis, value, fill = group, colour = group)) +
  geom_rain(
    rain.side = 'r',
    boxplot.args = list(color = "black", outlier.shape = NA, show.legend = FALSE, alpha = .8),
    # outlier.shape is a boxplot-only parameter and was removed from
    # violin.args: geom_violin ignores it with a warning
    violin.args = list(color = "black", alpha = .8),
    boxplot.args.pos = list(
      position = ggpp::position_dodgenudge(x = 0, width = 0.3), width = 0.3
    ),
    point.args = list(show.legend = FALSE, alpha = .5),
    violin.args.pos = list(
      width = 0.6, position = position_nudge(x = 0.16)
    ),
    point.args.pos = list(position = ggpp::position_dodgenudge(x = -0.25, width = 0.1))
  ) +
  # default palette used here: the group factor has more levels than col.grp
  # scale_fill_manual(values = col.grp) +
  # scale_color_manual(values = col.grp) +
  facet_wrap(level ~ change, scales = "free") +
  labs(title = "Learning rate updates", x = "", y = "") +
  theme_bw() + 
  theme(legend.position = "bottom", plot.title = element_text(hjust = 0.5), 
        legend.direction = "horizontal", text = element_text(size = 15),
        axis.text.x = element_blank(), axis.ticks.x = element_blank())